Generative Adversarial Networks (GAN)


By Prof. Seungchul Lee
http://iai.postech.ac.kr/
Industrial AI Lab at POSTECH

Table of Contents

Source

  • CS231n: CNN for Visual Recognition

1. Discriminative Model vs. Generative Model

  • Discriminative model




  • Generative model



2. Density Function Estimation

  • Probability
  • What if $x$ is an actual image from the training data? Then $x$ can be represented as, for example, a $64\times 64 \times 3$ dimensional vector.
    • the following images are realizations (samples) in this $64\times 64 \times 3$ dimensional space
  • Probability density function estimation problem
  • If $P_{\text{model}}(x)$ can be estimated as close to $P_{\text{data}}(x)$, then data can be generated by sampling from $P_{\text{model}}(x)$.

    • Note: the Kullback–Leibler divergence is a kind of distance measure between two distributions
  • Learn a deterministic transformation via a neural network
    • Start by sampling the code vector $z$ from a simple, fixed distribution such as a uniform distribution or a standard Gaussian $\mathcal{N}(0,I)$
    • Then this code vector is passed as input to a deterministic generator network $G$, which produces an output sample $x=G(z)$
    • This is the role a neural network plays in a generative model: a nonlinear mapping to a target probability density function



  • An example of a generator network that encodes a univariate distribution with two different modes
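The same idea can be sketched without any learning (a hypothetical illustration, not the lecture's code): a fixed, deterministic map $G$ that turns standard-Gaussian codes $z$ into samples from a two-mode distribution. The sign of $z$ selects the mode; the magnitude of $z$ spreads samples around it.

```python
import numpy as np

# Hypothetical deterministic generator: modes near -3 and +3
def G(z):
    return 3.0 * np.sign(z) + 0.5 * z

rng = np.random.default_rng(0)
z = rng.standard_normal(10_000)   # code vectors sampled from N(0, 1)
x = G(z)                          # generated samples: bimodal in x
```

A learned generator network does the same thing, except the mapping is parameterized and trained rather than hand-written.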

  • Generative model of high dimensional space
  • Generative model of images
    • learn a function that maps independent, normally distributed $z$ values to whatever latent variables might be needed by the model, and then map those latent variables to $x$ (as images)
    • first few layers to map the normally distributed $z$ to the latent values
    • then, use later layers to map those latent values to an image



3. Generative Adversarial Networks (GAN)

  • In generative modeling, we'd like to train a network that models a distribution, such as a distribution over images.

  • GANs do not work with an explicit density function!

  • Instead, take game-theoretic approach

3.1. Adversarial Nets Framework

  • One way to judge the quality of the model is to sample from it.

  • Train the model to produce samples that are indistinguishable from the real data, as judged by a discriminator network whose job is to tell real from fake





  • The idea behind Generative Adversarial Networks (GANs): train two different networks


  • Discriminator network: try to distinguish between real and fake data


  • Generator network: try to produce realistic-looking samples to fool the discriminator network


3.2. Objective Function of GAN

  • Think of a logistic regression classifier with cross-entropy loss between the prediction $h(x)$ and the label $y$


$$\text{loss} = -y \log h(x) - (1-y) \log (1-h(x))$$
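Written out in NumPy (a small illustration, not from the lecture code), the loss is cheap for a confident correct prediction and expensive for a confident wrong one:

```python
import numpy as np

# Cross-entropy loss for a scalar prediction h in (0, 1) and label y in {0, 1}
def cross_entropy(h, y):
    return -y * np.log(h) - (1 - y) * np.log(1 - h)

loss_correct = cross_entropy(0.9, 1)   # -log(0.9): small
loss_wrong   = cross_entropy(0.9, 0)   # -log(0.1): large
```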

  • To train the discriminator


  • To train the generator


  • Non-saturating game for training the generator

    • Early in learning, when $G$ is poor, $D$ can reject samples with high confidence because they are clearly different from the training data. In this case, $\log(1-D(G(z)))$ saturates.

    • Rather than training $G$ to minimize $\log(1-D(G(z)))$ we can train $G$ to maximize $\log D(G(z))$. This objective function provides much stronger gradients early in learning.
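The difference is easy to check numerically (an illustration, not from the slides): differentiate each objective with respect to $D(G(z))$ at a point where the discriminator confidently rejects the fake.

```python
# Discriminator confidently rejects a fake sample: D(G(z)) near 0
D_fake = 1e-3

grad_saturating = -1.0 / (1.0 - D_fake)   # d/dD of log(1 - D): magnitude ~1
grad_nonsat     =  1.0 / D_fake           # d/dD of log D: magnitude 1000
```

The saturating objective gives an almost constant, tiny gradient exactly when the generator most needs a training signal, while the non-saturating objective's gradient grows as $D(G(z)) \to 0$.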

3.3. Solving a Minimax Problem


Step 1: Fix $G$ and perform a gradient step to


$$\max_{D} E_{x \sim p_{\text{data}}(x)}\left[\log D(x)\right] + E_{z \sim p_{z}(z)}\left[\log (1-D(G(z)))\right]$$

Step 2: Fix $D$ and perform a gradient step to


$$\max_{G} E_{z \sim p_{z}(z)}\left[\log D(G(z))\right]$$

OR



Step 1: Fix $G$ and perform a gradient step to


$$\min_{D} E_{x \sim p_{\text{data}}(x)}\left[-\log D(x)\right] + E_{z \sim p_{z}(z)}\left[-\log (1-D(G(z)))\right]$$

Step 2: Fix $D$ and perform a gradient step to


$$\min_{G} E_{z \sim p_{z}(z)}\left[-\log D(G(z))\right]$$
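The two alternating steps can be seen in action on a toy, one-dimensional problem (a hypothetical sketch, not the MNIST networks below): real data $\sim \mathcal{N}(4,1)$, a one-parameter generator $G(z) = z + \theta$, and a logistic discriminator $D(x) = \sigma(wx + b)$, with gradients derived by hand.

```python
import numpy as np

def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

rng = np.random.default_rng(0)
theta, w, b = 0.0, 0.0, 0.0   # generator shift; discriminator weight, bias
lr, n_batch = 0.05, 128

for _ in range(3000):
    x_real = 4.0 + rng.standard_normal(n_batch)
    z = rng.standard_normal(n_batch)
    x_fake = z + theta

    # Step 1: fix G, gradient step on D to minimize -log D(x) - log(1 - D(G(z)))
    s_real = sigmoid(w * x_real + b)   # D(x_real)
    s_fake = sigmoid(w * x_fake + b)   # D(G(z))
    grad_w = np.mean(-(1 - s_real) * x_real) + np.mean(s_fake * x_fake)
    grad_b = np.mean(-(1 - s_real)) + np.mean(s_fake)
    w -= lr * grad_w
    b -= lr * grad_b

    # Step 2: fix D, gradient step on G to minimize -log D(G(z))  (non-saturating)
    s_fake = sigmoid(w * x_fake + b)
    grad_theta = np.mean(-(1 - s_fake) * w)   # chain rule, with dG/dtheta = 1
    theta -= lr * grad_theta

# theta drifts toward the real mean (4), where D can no longer separate the two
```

When the generator matches the data distribution, the discriminator's gradients vanish and $\theta$ settles near the real mean.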

4. GAN with MNIST

4.1. GAN Implementation

In [1]:
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
In [2]:
n_D_input = 28*28
n_D_hidden = 256
n_D_output = 1

n_G_input = 100
n_G_hidden = 256
n_G_output = 28*28
In [3]:
weights = {
    'G1' : tf.Variable(tf.random_normal([n_G_input, n_G_hidden], stddev = 0.01)),
    'G2' : tf.Variable(tf.random_normal([n_G_hidden, n_G_output], stddev = 0.01)),
    'D1' : tf.Variable(tf.random_normal([n_D_input, n_D_hidden], stddev = 0.01)),
    'D2' : tf.Variable(tf.random_normal([n_D_hidden, n_D_output], stddev = 0.01))
}

biases = {
    'G1' : tf.Variable(tf.zeros([n_G_hidden])),
    'G2' : tf.Variable(tf.zeros([n_G_output])),
    'D1' : tf.Variable(tf.zeros([n_D_hidden])),
    'D2' : tf.Variable(tf.zeros([n_D_output]))
}

z = tf.placeholder(tf.float32, [None, n_G_input])
x = tf.placeholder(tf.float32, [None, n_D_input])
In [4]:
def generator(G_input, weights, biases):
    hidden = tf.nn.relu(tf.matmul(G_input, weights['G1']) + biases['G1'])
    output = tf.nn.sigmoid(tf.matmul(hidden, weights['G2']) + biases['G2'])
    return output
In [5]:
def discriminator(D_input, weights, biases):
    hidden = tf.nn.relu(tf.matmul(D_input, weights['D1']) + biases['D1'])
    output = tf.nn.sigmoid(tf.matmul(hidden, weights['D2']) + biases['D2'])
    return output
In [6]:
def make_noise(n_batch, n_G_input):
    return np.random.normal(size = (n_batch, n_G_input))
In [7]:
G_output = generator(z, weights, biases)

D_fake = discriminator(G_output, weights, biases)
D_real = discriminator(x, weights, biases)

Step 1: Fix $G$ and perform a gradient step to

$$\min_{D} E_{x \sim p_{\text{data}}(x)}\left[-\log D(x)\right] + E_{z \sim p_{z}(z)}\left[-\log (1-D(G(z)))\right]$$

Step 2: Fix $D$ and perform a gradient step to

$$\min_{G} E_{z \sim p_{z}(z)}\left[-\log D(G(z))\right]$$
In [8]:
D_loss = tf.reduce_mean(- tf.log(D_real) - tf.log(1 - D_fake))
G_loss = tf.reduce_mean(- tf.log(D_fake))
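One practical caveat (an aside, not part of the original code): `tf.log(D_real)` returns `-inf` whenever the sigmoid saturates to exactly 0 or 1 in float32, which poisons the gradients with NaNs. A common remedy is to clip the probability by a small epsilon before taking the log; the effect is easy to see in plain NumPy:

```python
import numpy as np

p = np.array([0.0, 0.5, 1.0])                  # sigmoid outputs can saturate
with np.errstate(divide='ignore'):
    naive = -np.log(p)                         # -log(0) = inf: loss blows up
eps = 1e-8
stable = -np.log(np.clip(p, eps, 1 - eps))     # finite everywhere
```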
In [9]:
D_var_list = [weights['D1'], biases['D1'], weights['D2'], biases['D2']]
G_var_list = [weights['G1'], biases['G1'], weights['G2'], biases['G2']]
In [10]:
LR = 0.0002
D_optm = tf.train.AdamOptimizer(LR).minimize(D_loss, var_list = D_var_list)
G_optm = tf.train.AdamOptimizer(LR).minimize(G_loss, var_list = G_var_list)
In [11]:
n_batch = 100
n_iter = 50000
n_prt = 5000

sess = tf.Session()
sess.run(tf.global_variables_initializer())

D_loss_record = []
G_loss_record = []
for epoch in range(n_iter):
    train_x, train_y = mnist.train.next_batch(n_batch)
    noise = make_noise(n_batch, n_G_input)

    # discriminator and generator are separately trained 
    sess.run(D_optm, feed_dict = {x: train_x, z: noise})
    sess.run(G_optm, feed_dict = {z: noise}) 
    
    if epoch % n_prt == 0:
        D_loss_val = sess.run(D_loss, feed_dict = {x: train_x, z: noise})
        G_loss_val = sess.run(G_loss, feed_dict = {z: noise})
        D_loss_record.append(D_loss_val)
        G_loss_record.append(G_loss_val)
    
        print('Epoch:', '%04d' % epoch, 'D_loss: {:.4}'.format(D_loss_val), 'G_loss: {:.4}'.format(G_loss_val))
        
        plt.figure(figsize = (10,5))
        plt.subplot(1,2,1)
        noise = make_noise(n_batch, n_G_input)
        G_img = sess.run(G_output, feed_dict = {z: noise})   
        plt.imshow(G_img[0,:].reshape(28,28), 'gray')
        plt.axis('off')
        plt.subplot(1,2,2)
        noise = make_noise(n_batch, n_G_input)
        G_img = sess.run(G_output, feed_dict = {z: noise})   
        plt.imshow(G_img[0,:].reshape(28,28), 'gray')
        plt.axis('off')
        plt.show()
Epoch: 0000 D_loss: 1.359 G_loss: 0.7156
Epoch: 5000 D_loss: 0.4242 G_loss: 2.241
Epoch: 10000 D_loss: 0.5035 G_loss: 1.846
Epoch: 15000 D_loss: 0.5597 G_loss: 2.073
Epoch: 20000 D_loss: 0.7506 G_loss: 1.768
Epoch: 25000 D_loss: 0.8272 G_loss: 1.679
Epoch: 30000 D_loss: 0.8976 G_loss: 1.393
Epoch: 35000 D_loss: 0.9354 G_loss: 1.303
Epoch: 40000 D_loss: 0.8676 G_loss: 1.54
Epoch: 45000 D_loss: 0.9473 G_loss: 1.485
In [33]:
noise = make_noise(n_batch, n_G_input)
G_img = sess.run(G_output, feed_dict = {z: noise})

plt.figure(figsize = (5,5))
plt.imshow(G_img[0,:].reshape(28,28), 'gray')
plt.axis('off')
plt.show()

4.2. After Training

  • After training, use the generator network to generate new data


5. Other Tutorials

In [13]:
%%html
<center><iframe src="https://www.youtube.com/embed/9JpdAg6uMXs?rel=0" 
width="560" height="315" frameborder="0" allowfullscreen></iframe></center>
  • CS231n: CNN for Visual Recognition
In [14]:
%%html
<center><iframe src="https://www.youtube.com/embed/5WoItGTWV54?rel=0" 
width="560" height="315" frameborder="0" allowfullscreen></iframe></center>

MIT by Aaron Courville

In [15]:
%%html
<center><iframe src="https://www.youtube.com/embed/JVb54xhEw6Y?rel=0" 
width="560" height="315" frameborder="0" allowfullscreen></iframe></center>

Univ. of Waterloo by Ali Ghodsi

In [16]:
%%html
<center><iframe src="https://www.youtube.com/embed/7G4_Y5rsvi8?rel=0" 
width="560" height="315" frameborder="0" allowfullscreen></iframe></center>
In [17]:
%%html
<center><iframe src="https://www.youtube.com/embed/odpjk7_tGY0?rel=0" 
width="560" height="315" frameborder="0" allowfullscreen></iframe></center>
In [18]:
%%javascript
$.getScript('https://kmahelona.github.io/ipython_notebook_goodies/ipython_notebook_toc.js')